{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Dash cuML UMAP Colab Demo.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "mount_file_id": "1AXCIxBqWQOmn7WtxGU9YkTuSU_ZtAPSK",
      "authorship_tag": "ABX9TyMTQnvPWAXym8R5eBfjyMUU"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OGIcIsTjjQoH",
        "colab_type": "text"
      },
      "source": [
        "To start this Jupyter Dash app, please run all the cells below. Then, click on the **temporary** URL at the end of the last cell to open the app.\n",
        "\n",
        "By running this notebook, you agree to the terms and conditions of Kaggle.com, as well as the licenses specified in the original dataset: https://www.kaggle.com/mlg-ulb/creditcardfraud"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GNXpzpMaf8wL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install RAPIDS\n",
        "!git clone https://github.com/rapidsai/rapidsai-csp-utils.git\n",
        "!bash rapidsai-csp-utils/colab/rapids-colab.sh stable\n",
        "\n",
        "# `os` is not used directly in this cell, but update_modules.py is exec'd into\n",
        "# this namespace below and may rely on it (TODO confirm), so it stays imported.\n",
        "import sys, os\n",
        "\n",
        "# Place site-packages immediately ahead of dist-packages on sys.path\n",
        "# (presumably where rapids-colab.sh installs RAPIDS -- TODO confirm).\n",
        "# Any existing site-packages entry is dropped first so that re-running this\n",
        "# cell does not stack duplicate entries.\n",
        "dist_packages = '/usr/local/lib/python3.6/dist-packages'\n",
        "site_packages = '/usr/local/lib/python3.6/site-packages'\n",
        "sys.path = [entry for entry in sys.path if entry != site_packages]\n",
        "if dist_packages in sys.path:\n",
        "    dist_package_index = sys.path.index(dist_packages)\n",
        "    sys.path.insert(dist_package_index, site_packages)\n",
        "else:\n",
        "    # No dist-packages entry to anchor on: append rather than abort the cell\n",
        "    # with the ValueError that list.index() would raise.\n",
        "    sys.path.append(site_packages)\n",
        "\n",
        "# Read the helper script through a context manager so the file handle is closed.\n",
        "with open('rapidsai-csp-utils/colab/update_modules.py') as update_script:\n",
        "    exec(update_script.read(), globals())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JNNrZ3LWgJL-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# dash-bootstrap-components is capped below 1.0 because the layout cell uses\n",
        "# dbc.FormGroup, which the 1.0 release removed.\n",
        "!pip install -q jupyter-dash==0.3.0rc1 \"dash-bootstrap-components<1.0\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fTGoosFCkMWk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Fetch the credit card CSV; --no-clobber skips the download when the file\n",
        "# is already present in the working directory.\n",
        "!wget --no-clobber https://plotly-tutorials.s3-us-west-1.amazonaws.com/dash-sample-apps/creditcard.csv"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tZ3FKF6ZirB-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import time\n",
        "\n",
        "import cudf\n",
        "import cuml\n",
        "import cupy as cp\n",
        "import dash\n",
        "import dash_html_components as html\n",
        "import dash_core_components as dcc\n",
        "import dash_bootstrap_components as dbc\n",
        "from dash.dependencies import Input, Output, State\n",
        "from jupyter_dash import JupyterDash\n",
        "import plotly.express as px"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MPSBHkHei0zQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Read the transactions CSV into a cuDF DataFrame. DATA_DIR may name the\n",
        "# folder holding the file; when unset, the path is relative to the working directory.\n",
        "data_dir = os.getenv(\"DATA_DIR\", \"\")\n",
        "path = os.path.join(data_dir, \"creditcard.csv\")\n",
        "gdf = cudf.read_csv(path)\n",
        "\n",
        "# Rescale Time by 1/3600 (seconds -> hours, assuming the raw column is in seconds).\n",
        "gdf[\"Time\"] = gdf[\"Time\"] / 3600\n",
        "\n",
        "# Cap every Amount above 500 at 500, the upper end of the amount slider.\n",
        "gdf.loc[gdf[\"Amount\"] > 500, \"Amount\"] = 500"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1MwOg5tXi2Hp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Define app\n",
        "app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])\n",
        "server = app.server\n",
        "\n",
        "\n",
        "def _range_slider_col(label, slider_id, maximum, step, value, mark_every):\n",
        "    \"\"\"Build a half-width column containing a labelled range slider.\n",
        "\n",
        "    Args:\n",
        "        label: Text shown above the slider.\n",
        "        slider_id: Component id that callbacks use to read the slider value.\n",
        "        maximum: Upper bound of the slider; the lower bound is always 0.\n",
        "        step: Increment between selectable values.\n",
        "        value: Initial ``[low, high]`` selection.\n",
        "        mark_every: Spacing of the tick labels drawn from 0 up to ``maximum``.\n",
        "\n",
        "    Returns:\n",
        "        A ``dbc.Col`` (``md=6``) wrapping a ``dbc.FormGroup`` that holds the\n",
        "        label and the ``dcc.RangeSlider``.\n",
        "    \"\"\"\n",
        "    return dbc.Col(\n",
        "        dbc.FormGroup(\n",
        "            [\n",
        "                dbc.Label(label),\n",
        "                dcc.RangeSlider(\n",
        "                    id=slider_id,\n",
        "                    min=0,\n",
        "                    max=maximum,\n",
        "                    step=step,\n",
        "                    value=value,\n",
        "                    marks={i: str(i) for i in range(0, maximum + 1, mark_every)},\n",
        "                ),\n",
        "            ]\n",
        "        ),\n",
        "        md=6,\n",
        "    )\n",
        "\n",
        "\n",
        "# Two side-by-side range sliders: the hour window and the amount window.\n",
        "controls = dbc.Row(\n",
        "    [\n",
        "        _range_slider_col(\n",
        "            label=\"Time since start (h)\",\n",
        "            slider_id=\"slider-hours\",\n",
        "            maximum=50,\n",
        "            step=1,\n",
        "            value=[20, 30],\n",
        "            mark_every=10,\n",
        "        ),\n",
        "        _range_slider_col(\n",
        "            label=\"Transaction Amount ($)\",\n",
        "            slider_id=\"slider-amount\",\n",
        "            maximum=500,\n",
        "            step=5,\n",
        "            value=[200, 300],\n",
        "            mark_every=100,\n",
        "        ),\n",
        "    ],\n",
        ")\n",
        "\n",
        "\n",
        "# Define Layout\n",
        "app.layout = dbc.Container(\n",
        "    fluid=True,\n",
        "    children=[\n",
        "        html.H1(\"Dash cuML UMAP Demo\"),\n",
        "        html.Hr(),\n",
        "        dbc.Card(controls, body=True),\n",
        "        # Dash style dicts take camelCased CSS property names.\n",
        "        dcc.Graph(id=\"graph-umap\", style={\"height\": \"70vh\", \"maxHeight\": \"90vw\"}),\n",
        "        html.Div(id=\"output-info\"),\n",
        "    ],\n",
        "    style={\"maxWidth\": \"960px\", \"margin\": \"auto\"},\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hyjUEsglkrue",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@app.callback(\n",
        "    [Output(\"graph-umap\", \"figure\"), Output(\"output-info\", \"children\")],\n",
        "    [Input(\"slider-amount\", \"value\"), Input(\"slider-hours\", \"value\")],\n",
        ")\n",
        "def update_graph(amt, hrs):\n",
        "    \"\"\"Re-project the transactions selected by the sliders with cuML UMAP.\n",
        "\n",
        "    Args:\n",
        "        amt: ``[low, high]`` bounds from the amount slider (both inclusive).\n",
        "        hrs: ``[low, high]`` bounds from the hours slider (both inclusive).\n",
        "\n",
        "    Returns:\n",
        "        A ``(figure, alert)`` pair. Normally the figure is a scatter plot of\n",
        "        the first two embedding columns coloured by amount, and the alert\n",
        "        reports the number of rows and the elapsed time. When fewer than two\n",
        "        rows match the sliders, the figure is empty and the alert is a warning.\n",
        "    \"\"\"\n",
        "    t0 = time.time()\n",
        "    plot_title = \"UMAP projection of credit card transactions\"\n",
        "\n",
        "    # First, filter based on the slider values\n",
        "    time_mask = (gdf.Time >= hrs[0]) & (gdf.Time <= hrs[1])\n",
        "    amount_mask = (gdf.Amount >= amt[0]) & (gdf.Amount <= amt[1])\n",
        "    filt_df = gdf.loc[time_mask & amount_mask]\n",
        "\n",
        "    # A neighbour graph cannot be built from fewer than two samples, so skip\n",
        "    # the fit and tell the user instead of letting the callback raise.\n",
        "    n_selected = len(filt_df)\n",
        "    if n_selected < 2:\n",
        "        empty_fig = {\"data\": [], \"layout\": {\"title\": {\"text\": plot_title}}}\n",
        "        warning = dbc.Alert(\n",
        "            f\"Only {n_selected} transaction(s) match the selected ranges; \"\n",
        "            \"widen the sliders to compute a projection.\",\n",
        "            color=\"warning\",\n",
        "            dismissable=True,\n",
        "        )\n",
        "        return empty_fig, warning\n",
        "\n",
        "    # Then, select the features and train a UMAP model with cuML\n",
        "    features = filt_df.loc[:, \"V1\":\"V28\"].values\n",
        "    reducer = cuml.UMAP()\n",
        "    embedding = reducer.fit_transform(features)\n",
        "\n",
        "    # Convert the embedding back to numpy\n",
        "    embedding = cp.asnumpy(embedding)\n",
        "    amount = cp.asnumpy(filt_df.Amount.values.round(2))\n",
        "\n",
        "    # Create a plotly.express scatter plot\n",
        "    fig = px.scatter(\n",
        "        x=embedding[:, 0],\n",
        "        y=embedding[:, 1],\n",
        "        color=amount,\n",
        "        labels={\"color\": \"Amount ($)\"},\n",
        "        title=plot_title,\n",
        "    )\n",
        "\n",
        "    t1 = time.time()\n",
        "    out_msg = f\"Projected {embedding.shape[0]} transactions in {t1-t0:.2f}s.\"\n",
        "    alert = dbc.Alert(out_msg, color=\"success\", dismissable=True)\n",
        "\n",
        "    return fig, alert"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-vopNmMpkr92",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start the app server; as the intro cell notes, open the app from the\n",
        "# temporary URL shown in this cell's output.\n",
        "app.run_server(mode=\"external\")"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}